{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-alpha0\n",
      "GPU Available:  True\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import glob\n",
    "import pickle\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from lstm import LSTM_Simple\n",
    "from metrics import exact_match_metric\n",
    "from callbacks import NValidationSetsCallback, GradientLogger\n",
    "from generator import DataGenerator, DataGeneratorSeq\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "print(tf.__version__)\n",
    "print(\"GPU Available: \", tf.test.is_gpu_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define evaluation class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM_Evaluator:\n",
    "    \n",
    "    def __init__(self, path, ):\n",
    "        \n",
    "        with open(str(path/'settings.json'), 'r') as file:\n",
    "            self.settings_dict = json.load(file)\n",
    "            \n",
    "        if (path/'stoi.pkl').is_file():\n",
    "            with open(str(path/'stoi.pkl'), 'rb') as file:\n",
    "                self.token_index = pickle.load(file)        \n",
    "        else:\n",
    "            self.token_index = self.__get_stoi_from_data()\n",
    "        self.num_tokens = len(self.token_index)\n",
    "        \n",
    "        adam = Adam(lr=6e-4, beta_1=0.9, beta_2=0.995, epsilon=1e-9, decay=0.0, amsgrad=False, clipnorm=0.1)\n",
    "        self.lstm = LSTM_Simple(self.num_tokens, self.settings_dict['latent_dim'])\n",
    "        _ = self.lstm.get_model()\n",
    "        self.lstm.model.load_weights(str(path/'model.h5'))\n",
    "        self.lstm.model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=[exact_match_metric])\n",
    "        \n",
    "    def evaluate_model(self, input_texts, output_texts, teacher_forcing=True, batch_size=128, n_samples=1000):\n",
    "        max_seq_length  = max([len(txt_in)+len(txt_out) for txt_in, txt_out in zip(input_texts,output_texts)])\n",
    "        \n",
    "        params = {'batch_size': batch_size,\n",
    "                  'max_seq_length': max_seq_length,\n",
    "                  'num_tokens': self.num_tokens,\n",
    "                  'token_index': self.token_index,\n",
    "                  'num_thinking_steps': self.settings_dict[\"thinking_steps\"]\n",
    "                 }\n",
    "        \n",
    "        self.data_generator = DataGeneratorSeq(input_texts=input_texts,\n",
    "                                               target_texts=output_texts,\n",
    "                                               **params)\n",
    "        \n",
    "        if not teacher_forcing:\n",
    "            outputs_true, outputs_preds = self.predict_without_teacher(n_samples, max_seq_length)\n",
    "            exact_match = len([0 for out_true, out_preds in zip(outputs_true, outputs_preds) if out_true.strip()==out_preds.strip()])/len(outputs_true)\n",
    "        \n",
    "        else:\n",
    "            result = self.lstm.model.evaluate_generator(self.data_generator, verbose=1)\n",
    "            exact_match = result[1]\n",
    "            \n",
    "        return exact_match\n",
    "    \n",
    "    def predict_on_string(self, text, max_output_length=100):\n",
    "        \n",
    "        max_seq_length = len(text) + max_output_length\n",
    "\n",
    "        \n",
    "        params = {'batch_size': 1,\n",
    "                  'max_seq_length': max_seq_length,\n",
    "                  'num_tokens': self.num_tokens,\n",
    "                  'token_index': self.token_index,\n",
    "                  'num_thinking_steps': self.settings_dict[\"thinking_steps\"]\n",
    "                 }\n",
    "        \n",
    "        \n",
    "        self.data_generator = DataGeneratorSeq(input_texts=[text],\n",
    "                                               target_texts=['0'*max_output_length],\n",
    "                                               **params)\n",
    "        \n",
    "        outputs_true, outputs_preds = self.predict_without_teacher(1, max_seq_length)\n",
    "        \n",
    "        return outputs_preds[0].strip()\n",
    "\n",
    "    def predict_without_teacher(self, n_samples, max_seq_length, random=True):\n",
    "        \n",
    "        encoded_texts = [] \n",
    "        outputs_true = []\n",
    "        if random:\n",
    "            samples = np.random.choice(self.data_generator.indexes, n_samples, replace=False)\n",
    "        else:\n",
    "            samples = list(range(n_samples))\n",
    "        for i in samples:\n",
    "            input_len = len(input_texts_train[i])\n",
    "            sample = self.data_generator._DataGeneratorSeq__data_generation([i])         \n",
    "            input_len = len(self.data_generator.input_texts[i])\n",
    "            outputs_true.append(self.data_generator.target_texts[i])\n",
    "            x = sample[0][0][:input_len+self.settings_dict[\"thinking_steps\"]+1]\n",
    "            encoded_texts.append(np.expand_dims(x, axis=0))\n",
    "            \n",
    "        outputs_preds = self.lstm.decode_sample(encoded_texts, self.token_index, max_seq_length)\n",
    "        return outputs_true, outputs_preds\n",
    "        \n",
    "        \n",
    "    def __get_stoi_from_data(self):\n",
    "\n",
    "        \"\"\"\n",
    "        This function reloads all the data that was used to train and evalute\n",
    "        model to construct the string to integer map (stoi).\n",
    "        \"\"\"\n",
    "        \n",
    "        def concatenate_texts(path, pattern):\n",
    "            file_paths = list(path.glob('{}*.txt'.format(pattern)))\n",
    "            input_texts = []\n",
    "            target_texts = []\n",
    "\n",
    "            for file_path in file_paths:\n",
    "                with open(str(file_path), 'r', encoding='utf-8') as f:\n",
    "                    lines = f.read().split('\\n')[:-1]\n",
    "\n",
    "                input_texts.extend(lines[0::2])\n",
    "                target_texts.extend(['\\t' + target_text + '\\n' for target_text in lines[1::2]])\n",
    "            return input_texts, target_texts\n",
    "        \n",
    "        raw_path = Path(self.settings_dict['data_path'])\n",
    "        interpolate_path = raw_path/'interpolate'\n",
    "        extrapolate_path = raw_path/'extrapolate'\n",
    "        train_easy_path = raw_path/'train-easy/'\n",
    "        math_module = settings_dict[\"math_module\"]\n",
    "        train_level = settings_dict[\"train_level\"]\n",
    "        datasets = {\n",
    "            'train':(raw_path, 'train-' + train_level + '/' + math_module),\n",
    "            'interpolate':(interpolate_path, math_module),\n",
    "            'extrapolate':(extrapolate_path, math_module)\n",
    "                   }\n",
    "\n",
    "        input_texts = {}\n",
    "        target_texts = {}\n",
    "\n",
    "        for k, v in datasets.items():\n",
    "            input_texts[k], target_texts[k] = concatenate_texts(v[0], v[1])\n",
    "        \n",
    "        all_input_texts = sum(input_texts.values(), [])\n",
    "        all_target_texts = sum(target_texts.values(), [])\n",
    "\n",
    "        input_characters = set(''.join(all_input_texts))\n",
    "        target_characters = set(''.join(all_target_texts))\n",
    "\n",
    "        tokens = sorted(list(input_characters | target_characters))\n",
    "        token_index = dict([(char, i) for i, char in enumerate(tokens)])\n",
    "        \n",
    "        return token_index\n",
    "        \n",
    "        \n",
    "        \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = Path('../../models/js0kldpwp1nhos/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'batch_size': 1024,\n",
       " 'data_path': '/storage/git/deep-math/data/raw/v1.0/',\n",
       " 'epochs': 1,\n",
       " 'latent_dim': 2048,\n",
       " 'math_module': 'arithmetic',\n",
       " 'save_path': '/artifacts/',\n",
       " 'saved_model': '/storage/artifacts/j4bu146wamlr9/model.h5',\n",
       " 'thinking_steps': 16,\n",
       " 'train_level': '*'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "settings_path = model_path/'settings.json'\n",
    "\n",
    "with open(str(settings_path), 'r') as file:\n",
    "    settings_dict = json.load(file)\n",
    "\n",
    "\n",
    "raw_path = Path(settings_dict['data_path'])\n",
    "interpolate_path = raw_path/'interpolate'\n",
    "extrapolate_path = raw_path/'extrapolate'\n",
    "train_easy_path = raw_path/'train-easy/'\n",
    "\n",
    "settings_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def concatenate_texts(path, pattern):\n",
    "    file_paths = list(path.glob('{}*.txt'.format(pattern)))\n",
    "    \n",
    "    input_texts = []\n",
    "    target_texts = []\n",
    "\n",
    "    for file_path in file_paths:\n",
    "        with open(str(file_path), 'r', encoding='utf-8') as f:\n",
    "            lines = f.read().split('\\n')[:-1]\n",
    "\n",
    "        input_texts.extend(lines[0::2])\n",
    "        target_texts.extend(['\\t' + target_text + '\\n' for target_text in lines[1::2]])\n",
    "        \n",
    "    return input_texts, target_texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def concatenate_texts_individual(path, pattern):\n",
    "    file_paths = list(path.glob('{}*.txt'.format(pattern)))\n",
    "    \n",
    "    input_texts = {}\n",
    "    target_texts = {}\n",
    "\n",
    "    for file_path in file_paths:\n",
    "        \n",
    "        input_texts[file_path] = []\n",
    "        target_texts[file_path] = []\n",
    "        \n",
    "        with open(str(file_path), 'r', encoding='utf-8') as f:\n",
    "            lines = f.read().split('\\n')[:-1]\n",
    "\n",
    "        input_texts[file_path].extend(lines[0::2])\n",
    "        target_texts[file_path].extend(['\\t' + target_text + '\\n' for target_text in lines[1::2]])\n",
    "        \n",
    "    return input_texts, target_texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of set train is 17999982\n",
      "Length of set interpolate is 90000\n",
      "Length of set extrapolate is 60000\n"
     ]
    }
   ],
   "source": [
    "math_module = settings_dict[\"math_module\"]\n",
    "train_level = settings_dict[\"train_level\"]\n",
    "\n",
    "datasets = {\n",
    "    'train':(raw_path, 'train-' + train_level + '/' + math_module),\n",
    "    'interpolate':(interpolate_path, math_module),\n",
    "    'extrapolate':(extrapolate_path, math_module)\n",
    "           }\n",
    "\n",
    "input_texts = {}\n",
    "target_texts = {}\n",
    "\n",
    "for k, v in datasets.items():\n",
    "    input_texts[k], target_texts[k] = concatenate_texts(v[0], v[1])\n",
    "    print('Length of set {} is {}'.format(k, len(input_texts[k])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_texts_train, input_texts_valid, target_texts_train, target_texts_valid = train_test_split(input_texts['train'], target_texts['train'], test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of training samples: 14399985\n",
      "Number of validation samples: 3599997\n"
     ]
    }
   ],
   "source": [
    "print('Number of training samples:', len(input_texts_train))\n",
    "print('Number of validation samples:', len(input_texts_valid))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INPUT: What is 2 - (1 + -5) - 11?\n",
      "OUTPUT: -5\n"
     ]
    }
   ],
   "source": [
    "print('INPUT:', input_texts['train'][42])\n",
    "print('OUTPUT:', target_texts['train'][42].strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm_eval = LSTM_Evaluator(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_sample=1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8/8 [==============================] - 4s 492ms/step - loss: 0.0082 - exact_match_metric: 0.7285\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.7285156"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstm_eval.evaluate_model(input_texts_train[:test_sample], target_texts_train[:test_sample])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8/8 [==============================] - 4s 528ms/step - loss: 0.0078 - exact_match_metric: 0.7363\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.7363281"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstm_eval.evaluate_model(input_texts_valid[:test_sample], target_texts_valid[:test_sample])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8/8 [==============================] - 2s 229ms/step - loss: 0.0112 - exact_match_metric: 0.6592\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.6591797"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstm_eval.evaluate_model(input_texts['interpolate'][:test_sample], target_texts['interpolate'][:test_sample])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8/8 [==============================] - 2s 263ms/step - loss: 0.0609 - exact_match_metric: 0.2070\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.20703125"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstm_eval.evaluate_model(input_texts['extrapolate'][:test_sample], target_texts['extrapolate'][:test_sample])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## We can also test an indiviual string:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'-4'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstm_eval.predict_on_string('1 / 7')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
